{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Recommender Systems with DGL\n",
    "\n",
    "## Introduction\n",
    "\n",
    "Graph Neural Networks (GNN), as a methodology of learning representations on graphs, has gained much attention recently.  Various models such as Graph Convolutional Networks, GraphSAGE, etc. are proposed to obtain representations of whole graphs, or nodes on a single graph.\n",
    "\n",
    "A primary goal of recommendation is to automatically make predictions about a user's interest, e.g. whether/how a user would interact with a set of items, given the interaction history of the user herself, as well as the histories of other users.  The user-item interaction can also be viewed as a bipartite graph, where users and items form two sets of nodes, and edges connecting them stands for interactions.  The problem can then be formulated as a *link-prediction* problem, where we try to predict whether an edge (of a given type) exists between two nodes.\n",
    "\n",
    "Based on this intuition, the academia developed multiple new models for recommendation, including but not limited to:\n",
    "\n",
    "* Geometric Learning Approaches\n",
    "  * [Geometric Matrix Completion](https://papers.nips.cc/paper/5938-collaborative-filtering-with-graph-information-consistency-and-scalable-methods.pdf)\n",
    "  * [Recurrent Multi-graph CNN](https://arxiv.org/pdf/1704.06803.pdf)\n",
    "* Graph-convolutional Approaches\n",
    "  * Models such as [R-GCN](https://arxiv.org/pdf/1703.06103.pdf) or [GraphSAGE](https://github.com/stellargraph/stellargraph/tree/develop/demos/link-prediction/hinsage) also apply.\n",
    "  * [Graph Convolutional Matrix Completion](https://arxiv.org/abs/1706.02263)\n",
    "  * [PinSage](https://arxiv.org/pdf/1806.01973.pdf)\n",
    "  \n",
    "In this hands-on tutorial, we will demonstrate how to write GraphSAGE in DGL + MXNet, and how to apply it in a recommender system setting."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dgl\n",
    "import dgl.function as FN\n",
    "\n",
    "# Load MXNet as backend\n",
    "dgl.load_backend('mxnet')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mxnet as mx\n",
    "from mxnet import ndarray as nd, autograd, gluon\n",
    "from mxnet.gluon import nn\n",
    "import numpy as np\n",
    "\n",
    "np.random.seed(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading data\n",
    "\n",
    "In this tutorial, we focus on rating prediction on MovieLens-100K dataset.  The data comes from [MovieLens](http://files.grouplens.org/datasets/movielens/ml-1m.zip) and is shipped with the notebook already.\n",
    "\n",
    "After loading and train-validation-test-splitting the dataset, we process features into categorical variables (i.e. integers).  We then store them as node features on the graph.\n",
    "\n",
    "Since user features and item features are different, we pad both types of features with zeros.\n",
    "\n",
    "All of the above is encapsulated in `movielens.MovieLens` class for clarity of this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import movielens\n",
    "\n",
    "ml = movielens.MovieLens('ml-100k')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "g = ml.g\n",
    "print('#vertices:', g.number_of_nodes())\n",
    "print('#edges:', g.number_of_edges())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## See the features in the MovieLens dataset\n",
    "\n",
    "The MovieLens dataset has some user features and movie features.\n",
    "\n",
    "User features:\n",
    "* age,\n",
    "* gender,\n",
    "* occupation,\n",
    "* zip code,\n",
    "\n",
    "Movie features:\n",
    "* genre,\n",
    "* year,\n",
    "\n",
    "We use one-hot encoding for \"age\", \"gender\", \"occupation\", \"zip code\" and \"year\". \"genre\" uses multi-hop encoding.\n",
    "\n",
    "In additon, there is a node data \"type\" that indicates the node type in the bipartite graph. Nodes with \"type=1\" are user nodes and \"type=0\" are movie nodes.\n",
    "\n",
    "User nodes don't have features of movie nodes. These features on the user nodes are filled with 0. Similarly, movie nodes don't have features of user nodes and these features are filled with 0."
   ]
  },
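  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of the difference between the two encodings mentioned above, the snippet below builds a one-hot vector (exactly one active position) and a multi-hot vector (several active positions) in plain Python. The genre vocabulary, the occupation id and the helper names are made up for illustration and are not taken from `movielens.MovieLens`, which may store the features differently.\n",
    "\n",
    "```python\n",
    "# Hypothetical genre vocabulary, used only for this illustration.\n",
    "GENRES = ['Action', 'Comedy', 'Drama']\n",
    "\n",
    "def one_hot(index, size):\n",
    "    # Exactly one position is set to 1.\n",
    "    vec = [0.0] * size\n",
    "    vec[index] = 1.0\n",
    "    return vec\n",
    "\n",
    "def multi_hot(indices, size):\n",
    "    # Every listed position is set to 1.\n",
    "    vec = [0.0] * size\n",
    "    for index in indices:\n",
    "        vec[index] = 1.0\n",
    "    return vec\n",
    "\n",
    "# A user whose (made-up) occupation id is 2 out of 4 possible values.\n",
    "occupation_vec = one_hot(2, 4)\n",
    "# A movie tagged with both Action and Drama.\n",
    "genre_vec = multi_hot([GENRES.index('Action'), GENRES.index('Drama')], len(GENRES))\n",
    "```\n",
    "\n",
    "A categorical feature such as occupation has one value per user, so a single position is active. A movie can belong to several genres at once, which is why genre needs the multi-hot form."
   ]
  },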
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(g.ndata)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('#user nodes:', mx.nd.sum(g.ndata['type'] == 1).asnumpy())\n",
    "print('#movie nodes:', mx.nd.sum(g.ndata['type'] == 0).asnumpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sampling in DGL\n",
    "\n",
    "When the graph scales up, it's impractical to run graph neural networks on the full graph because the node embeddings couldn't fit in the GPU memory.\n",
    "\n",
    "A natural solution would be partitioning the nodes and computing the embeddings one partition (minibatch) at a time.  The nodes at one convolution layer only depends on their neighbors, rather than all the nodes in the graph, hence reducing the computational cost.  However, if we have multiple layers, and some of the nodes have a lot of neighbors (which is often the case since the degree distribution of many real-world graphs follow [power-law](https://en.wikipedia.org/wiki/Scale-free_network)), computing the embedding of a target node still depends on a large number of nodes in the graph.\n",
    "\n",
    "Please see our [sampling tutorial](https://doc.dgl.ai/tutorials/models/5_giant_graph/1_sampling_mx.html#sphx-glr-tutorials-models-5-giant-graph-1-sampling-mx-py) for details.\n",
    "\n",
    "The data and computation dependency of computing the embedding on target node 1 is illustrated in the figure below:\n",
    "\n",
    "<img src=\"https://s3.us-east-2.amazonaws.com/dgl.ai/amlc_tutorial/Dependency.png\" width=\"400\">"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*Neighbor sampling* is an answer to further reduce the cost of computing node embeddings.  When aggregating messages, instead of collecting from all neighboring nodes, we only collect from some of the randomly-sampled (for instance, uniform sampling at most K neighbors without replacement) neighbors.\n",
    "\n",
    "<img src=\"https://s3.us-east-2.amazonaws.com/dgl.ai/amlc_tutorial/neighbor_sampling.png\" width=\"600\">"
   ]
  },
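  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea of uniform neighbor sampling can be sketched in a few lines of plain Python. This is only an illustration under simple assumptions (a toy adjacency list and a hypothetical helper `sample_neighbors`); it is not how DGL implements its samplers.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "def sample_neighbors(adjacency, node, k, rng):\n",
    "    # Uniformly sample at most k neighbors of `node` without replacement.\n",
    "    neighbors = adjacency[node]\n",
    "    if len(neighbors) <= k:\n",
    "        # Fewer than k neighbors: keep all of them.\n",
    "        return list(neighbors)\n",
    "    return rng.sample(neighbors, k)\n",
    "\n",
    "# Hypothetical toy graph stored as an adjacency list.\n",
    "adjacency = {0: [1, 2, 3, 4, 5, 6], 1: [0], 2: [0, 3]}\n",
    "rng = random.Random(0)\n",
    "\n",
    "sampled = sample_neighbors(adjacency, 0, 3, rng)    # 3 of the 6 neighbors\n",
    "all_kept = sample_neighbors(adjacency, 1, 3, rng)   # node 1 has a single neighbor\n",
    "```\n",
    "\n",
    "With K fixed, the number of messages aggregated per node is bounded by K regardless of the node's true degree, which is what keeps high-degree nodes from dominating the cost."
   ]
  },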
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "DGL provides `NodeFlow` that stores the computation dependency of nodes in a graph convolutional network. Below shows hwo we can run GraphSage on `NodeFlow`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Recommendation model\n",
    "\n",
    "<img src=\"https://s3.us-east-2.amazonaws.com/dgl.ai/amlc_tutorial/rec_process.png\" width=\"800\">\n",
    "\n",
    "Recommendation with graph neural networks has two steps:\n",
    "* graph encoder: use graph neural networks to compute node embeddings.\n",
    "* edge decoder: compute scores on edges with user embeddings and movie embeddings.\n",
    "\n",
    "The only difference of this class is that it applies on `NodeFlow` instead of `DGLGraph`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GNNRecommender(nn.Block):\n",
    "    def __init__(self, encoder, decoder):\n",
    "        super(GNNRecommender, self).__init__()\n",
    "        \n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "\n",
    "    def forward(self, nf, users, items):\n",
    "        h = self.encoder(nf)\n",
    "        h_users = h[nf.map_from_parent_nid(-1, users, True)]\n",
    "        h_items = h[nf.map_from_parent_nid(-1, items, True)]\n",
    "        return self.decoder(h_users, h_items, users, items)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GraphSage encoder on NodeFlow\n",
    "\n",
    "The encoder does two things:\n",
    "* generate the initial user and movie embeddings,\n",
    "* run GraphSAGE layers on nodes multiple times to compute the final embeddings for rating prediction."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Algorithm\n",
    "\n",
    "The algorithm of a single GraphSAGE layer goes as follows for each node $v$:\n",
    "\n",
    "1. $h_{\\mathcal{N}(v)} \\gets \\mathtt{Sum}_{u \\in \\mathcal{N}(v)} h_{u}$\n",
    "2. $h_{v} \\gets \\sigma\\left(W \\cdot \\mathtt{CONCAT}(h_v, h_{\\mathcal{N}(v)}/d_{\\mathcal{N}(v)})\\right)$\n",
    "3. $h_{v} \\gets h_{v} / \\lVert h_{v} \\rVert_2$\n",
    "\n",
    "### Slight modification on the original GraphSage model\n",
    "\n",
    "In practice, the MovieLens dataset has two types of nodes: users and movies. We need to perform separate node update functions on the two types of nodes.\n",
    "\n",
    "For the movie nodes,\n",
    "\n",
    "$h_{m} \\gets \\sigma\\left(W0 \\cdot \\mathtt{CONCAT}(h_m, h_{\\mathcal{N}(m)} / d_{\\mathcal{N}(m)})\\right)$, \n",
    "$h_{m} \\gets h_{m} / \\lVert h_{m} \\rVert_2$\n",
    "\n",
    "For the user nodes,\n",
    "\n",
    "$h_{u} \\gets \\sigma\\left(W1 \\cdot \\mathtt{CONCAT}(h_u, h_{\\mathcal{N}(u)} / d_{\\mathcal{N}(u)})\\right)$,\n",
    "$h_{u} \\gets h_{u} / \\lVert h_{u} \\rVert_2$"
   ]
  },
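  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the three steps of a single layer concrete, here is a small NumPy sketch on a toy graph. It is an illustration only, not the DGL implementation used in this notebook: the function name `graphsage_layer`, the toy neighbor lists and the random weights are assumptions, the weight matrix is applied to row vectors, and a leaky ReLU stands in for the nonlinearity.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def graphsage_layer(h, neighbors, W, slope=0.1):\n",
    "    # h: (num_nodes, d) embeddings, neighbors: list of neighbor-id lists,\n",
    "    # W: (2 * d, d_out) weight matrix applied to row vectors.\n",
    "    h_new = np.zeros((h.shape[0], W.shape[1]))\n",
    "    for v, nbrs in enumerate(neighbors):\n",
    "        # Step 1: sum the neighbor embeddings.\n",
    "        h_agg = h[nbrs].sum(axis=0) if nbrs else np.zeros(h.shape[1])\n",
    "        # Step 2: concatenate with the degree-normalized aggregate and project.\n",
    "        z = np.concatenate([h[v], h_agg / max(len(nbrs), 1)]) @ W\n",
    "        z = np.where(z > 0, z, slope * z)  # leaky ReLU as the nonlinearity\n",
    "        # Step 3: scale the result to unit L2 norm.\n",
    "        h_new[v] = z / max(np.linalg.norm(z), 1e-6)\n",
    "    return h_new\n",
    "\n",
    "# Hypothetical toy input: 4 nodes with 3-dimensional embeddings.\n",
    "rng = np.random.default_rng(0)\n",
    "h = rng.normal(size=(4, 3))\n",
    "W = rng.normal(size=(6, 2))\n",
    "out = graphsage_layer(h, [[1, 2], [0], [0, 3], []], W)\n",
    "```\n",
    "\n",
    "The per-type variant described above would keep two such weight matrices and pick one according to whether the node is a user or a movie."
   ]
  },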
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GraphSageEncoder(nn.Block):\n",
    "    def __init__(self, embedding_size, n_layers, G):\n",
    "        super(GraphSageEncoder, self).__init__()\n",
    "\n",
    "        self.G = G\n",
    "        self.user_nodes = G.filter_nodes(lambda nodes: nodes.data['type'] == 1).astype(np.int64)\n",
    "        self.movie_nodes = G.filter_nodes(lambda nodes: nodes.data['type'] == 0).astype(np.int64)\n",
    "\n",
    "        self.user_layers = nn.Sequential()\n",
    "        self.movie_layers = nn.Sequential()\n",
    "        for i in range(n_layers):\n",
    "            self.user_layers.add(GraphSageNodeUpdate(embedding_size))\n",
    "            self.movie_layers.add(GraphSageNodeUpdate(embedding_size))\n",
    "\n",
    "        # One-hot encoding for each node.\n",
    "        node_emb = nn.Embedding(G.number_of_nodes() + 1, embedding_size)\n",
    "        self.user_emb = UserEmbedding(G, embedding_size, node_emb)\n",
    "        self.movie_emb = MovieEmbedding(G, embedding_size, node_emb)\n",
    "\n",
    "    def forward(self, nf):\n",
    "        nf.copy_from_parent(edge_embed_names=None)\n",
    "\n",
    "        # Generate embeddings on user nodes and movie nodes.\n",
    "        for i in range(nf.num_layers):            \n",
    "            layer_nodes = nf.layer_nid(i)\n",
    "            node_type = nf.layers[i].data['type'].asnumpy()\n",
    "            user_nodes = layer_nodes[np.nonzero(node_type == 1)]\n",
    "            movie_nodes = layer_nodes[np.nonzero(node_type == 0)]\n",
    "            if len(user_nodes) > 0:\n",
    "                nf.apply_layer(i, lambda nodes: {'h': self.user_emb(nodes.data, user_nodes)},\n",
    "                               v=user_nodes)\n",
    "            if len(movie_nodes) > 0:\n",
    "                nf.apply_layer(i, lambda nodes: {'h': self.movie_emb(nodes.data, movie_nodes)},\n",
    "                               v=movie_nodes)\n",
    "\n",
    "        # Apply GraphSage layers on NodeFlow.\n",
    "        for i in range(nf.num_blocks):\n",
    "            nf.layers[i+1].data['deg'] = nf.layer_in_degree(i+1).astype(np.float32)\n",
    "            layer_nodes = nf.layer_nid(i+1).asnumpy()\n",
    "            node_type = nf.layers[i+1].data['type'].asnumpy()\n",
    "            user_nodes = layer_nodes[node_type == 1]\n",
    "            movie_nodes = layer_nodes[node_type == 0]\n",
    "            nf.block_compute(i, FN.copy_src('h', 'h'), FN.sum('h', 'h_agg'))\n",
    "            if len(user_nodes) > 0:\n",
    "                nf.apply_layer(i+1, self.user_layers[i], v=user_nodes)\n",
    "            if len(movie_nodes) > 0:\n",
    "                nf.apply_layer(i+1, self.movie_layers[i], v=movie_nodes)\n",
    "\n",
    "        return nf.layers[-1].data['h']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GraphSageNodeUpdate(nn.Block):\n",
    "    def __init__(self, feature_size):\n",
    "        super(GraphSageNodeUpdate, self).__init__()\n",
    "\n",
    "        self.feature_size = feature_size\n",
    "        self.W = nn.Dense(feature_size)\n",
    "        self.leaky_relu = nn.LeakyReLU(0.1)\n",
    "\n",
    "    def forward(self, nodes):\n",
    "        # Node embedding from the previous layer.\n",
    "        h = nodes.data['h']\n",
    "        # Aggregation of the node embeddings in the neighborhood\n",
    "        h_agg = nodes.data['h_agg']\n",
    "        # Degree of the vertex.\n",
    "        deg = nodes.data['deg'].expand_dims(1)\n",
    "        h_concat = nd.concat(h, h_agg / nd.maximum(deg, 1e-6), dim=1)\n",
    "        h_new = self.leaky_relu(self.W(h_concat))\n",
    "        # Layer norm\n",
    "        return {'h': h_new / nd.maximum(h_new.norm(axis=1, keepdims=True), 1e-6)}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compute node embeddings on the MovieLens dataset\n",
    "\n",
    "User nodes and movie nodes have different sets of features. Thus, we need to generate embeddings differently.\n",
    "\n",
    "User nodes have categorial features of \"age\", \"gender\", \"occupation\" and \"zip code\". These features are all one-hot encodings. In addition, we add one-hot encoding for every user node. To generate user embedding, we add all of these embeddings together."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class UserEmbedding(nn.Block):\n",
    "    def __init__(self, G, feature_size, node_emb):\n",
    "        super(UserEmbedding, self).__init__()\n",
    "\n",
    "        # Embedding matrices for one-hot encoding.\n",
    "        self.emb_age = nn.Embedding(G.ndata['age'].max().asscalar() + 1,\n",
    "                                    feature_size)\n",
    "        self.emb_gender = nn.Embedding(G.ndata['gender'].max().asscalar() + 1,\n",
    "                                       feature_size)\n",
    "        self.emb_occupation = nn.Embedding(G.ndata['occupation'].max().asscalar() + 1,\n",
    "                                           feature_size)\n",
    "        self.emb_zip = nn.Embedding(G.ndata['zip'].max().asscalar() + 1,\n",
    "                                    feature_size)\n",
    "\n",
    "        # One-hot encoding for each node.\n",
    "        self.node_emb = node_emb\n",
    "\n",
    "    def forward(self, ndata, nid):\n",
    "        h = self.node_emb(nid + 1)\n",
    "        extra_repr = []\n",
    "        extra_repr.append(self.emb_age(ndata['age']))\n",
    "        extra_repr.append(self.emb_gender(ndata['gender']))\n",
    "        extra_repr.append(self.emb_occupation(ndata['occupation']))\n",
    "        extra_repr.append(self.emb_zip(ndata['zip']))\n",
    "        return h + nd.stack(*extra_repr, axis=0).sum(axis=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Movie nodes \"year\", \"genre\". \"year\" is one-hot encoding, \"genre\" is stored in a float32 dense matrix. Like user nodes, we add one-hot encoding to every movie node. To generate movie embedding, we add all of these embeddings together."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MovieEmbedding(nn.Block):\n",
    "    def __init__(self, G, feature_size, node_emb):\n",
    "        super(MovieEmbedding, self).__init__()\n",
    "        self.emb_year = nn.Embedding(G.ndata['year'].max().asscalar() + 1,\n",
    "                                     feature_size)\n",
    "\n",
    "        # Linear projection for float32 features.\n",
    "        seq = nn.Sequential()\n",
    "        with seq.name_scope():\n",
    "            seq.add(nn.Dense(feature_size))\n",
    "            seq.add(nn.LeakyReLU(0.1))\n",
    "        self.proj_genre = seq\n",
    "\n",
    "        # One-hot encoding for each node.\n",
    "        self.node_emb = node_emb\n",
    "\n",
    "    def forward(self, ndata, nid):\n",
    "        h = self.node_emb(nid + 1)\n",
    "        extra_repr = []\n",
    "        extra_repr.append(self.emb_year(ndata['year']))\n",
    "        extra_repr.append(self.proj_genre(ndata['genre']))\n",
    "        return h + nd.stack(*extra_repr, axis=0).sum(axis=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Rating prediction\n",
    "\n",
    "For recommendation, the rating on item $j$ by user $i$ is defined by $u_i^T v_j$.\n",
    "\n",
    "In practice, recommendation models have user bias term and movie bias term: $u_i^T v_j + b_{u_i} + b_{v_j}$."
   ]
  },
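  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The scoring rule with bias terms can be checked on a tiny example. The NumPy sketch below is only an illustration: the function name `predict_ratings` and all numbers are made up, and each row pairs one user embedding with one item embedding.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def predict_ratings(h_users, h_items, user_bias, item_bias):\n",
    "    # Row-wise dot product of user and item embeddings plus both bias terms.\n",
    "    return (h_users * h_items).sum(axis=1) + user_bias + item_bias\n",
    "\n",
    "# Two hypothetical (user, item) pairs with 2-dimensional embeddings.\n",
    "h_users = np.array([[1.0, 2.0], [0.5, 0.0]])\n",
    "h_items = np.array([[3.0, 1.0], [2.0, 4.0]])\n",
    "user_bias = np.array([0.1, 0.2])\n",
    "item_bias = np.array([0.0, -0.2])\n",
    "\n",
    "scores = predict_ratings(h_users, h_items, user_bias, item_bias)\n",
    "```\n",
    "\n",
    "The bias terms let the model capture that some users rate generously and some movies are broadly liked, independently of how well a particular user matches a particular movie."
   ]
  },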
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DotDecoder(nn.Block):\n",
    "    def __init__(self, num_nodes):\n",
    "        super(DotDecoder, self).__init__()\n",
    "        \n",
    "        with self.name_scope():\n",
    "            self.biases = self.params.get(\n",
    "                'node_biases',\n",
    "                init=mx.init.Zero(),\n",
    "                shape=(num_nodes+1,))\n",
    "            \n",
    "    def forward(self, h_users, h_items, users, items):\n",
    "        user_biases = self.biases.data()[users+1]\n",
    "        item_biases = self.biases.data()[items+1]\n",
    "        return (h_users * h_items).sum(1) + user_biases + item_biases"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Edge sampling for rating prediction\n",
    "\n",
    "For rating prediction, we first need to sample a set of edges. On the endpoint nodes of the edges, we run GraphSage to compute their embeddings. Therefore, we sample edges along with `NodeFlow`s. For each batch of sampled edges, we construct a `NodeFlow` for the endpoint of each side. This is illustrated with the figure below:\n",
    "\n",
    "<img src=\"https://s3.us-east-2.amazonaws.com/dgl.ai/amlc_tutorial/rating_pred.png\" width=\"300\">\n",
    "\n",
    "We can use `NeighborSampler` to implement this edge sampler. When the edge sampler constructs a batch, it creates a `NodeFlow` for the endpoint nodes of both sides as well as endpoint nodes and ratings of the edges."
   ]
  },
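  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The bookkeeping behind this edge sampler can be sketched without DGL. The plain-Python snippet below splits the edges into batches and then lists, for every batch, its source nodes followed by its destination nodes, so that a node sampler consuming the list in order sees the endpoints of one edge batch at a time. The helper names and the toy edges are hypothetical.\n",
    "\n",
    "```python\n",
    "def make_edge_batches(src, dst, rating, batch_size):\n",
    "    # Split three parallel edge lists into consecutive batches.\n",
    "    batches = []\n",
    "    for i in range(0, len(src), batch_size):\n",
    "        j = i + batch_size\n",
    "        batches.append((src[i:j], dst[i:j], rating[i:j]))\n",
    "    return batches\n",
    "\n",
    "def interleave_seeds(batches):\n",
    "    # Per batch: all source nodes first, then all destination nodes.\n",
    "    seeds = []\n",
    "    for s, d, _ in batches:\n",
    "        seeds.extend(s)\n",
    "        seeds.extend(d)\n",
    "    return seeds\n",
    "\n",
    "# Hypothetical toy edges: users 0-4 rate movies 10-14.\n",
    "src = [0, 1, 2, 3, 4]\n",
    "dst = [10, 11, 12, 13, 14]\n",
    "rating = [5, 3, 4, 1, 2]\n",
    "\n",
    "batches = make_edge_batches(src, dst, rating, batch_size=2)\n",
    "seeds = interleave_seeds(batches)\n",
    "```\n",
    "\n",
    "Reading `2 * batch_size` seeds at a time then yields the endpoints of exactly one edge batch per step. The last batch may be shorter than the others."
   ]
  },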
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 1024                       # The number of target nodes in a batch.\n",
    "num_neighbors = 5                       # The number of sampled neighbors on each node\n",
    "num_layers = 1.                         # The number of layers in GraphSage.\n",
    "\n",
    "class EdgeSampler:\n",
    "    def __init__(self, g, src, dst, rating):\n",
    "        shuffle_idx = nd.from_numpy(np.random.permutation(g.number_of_edges()))\n",
    "        src_shuffled = src[shuffle_idx]\n",
    "        dst_shuffled = dst[shuffle_idx]\n",
    "        rating_shuffled = rating[shuffle_idx]\n",
    "        \n",
    "        self.src_batches = []\n",
    "        self.dst_batches = []\n",
    "        self.rating_batches = []\n",
    "        for i in range(0, g.number_of_edges(), batch_size):\n",
    "            j = min(i + batch_size, g.number_of_edges())\n",
    "            self.src_batches.append(src_shuffled[shuffle_idx[i:j]])\n",
    "            self.dst_batches.append(dst_shuffled[shuffle_idx[i:j]])\n",
    "            self.rating_batches.append(rating_shuffled[shuffle_idx[i:j]])\n",
    "\n",
    "        seed_nodes = nd.concat(*sum([[s, d] for s, d in zip(self.src_batches, self.dst_batches)], []), dim=0)\n",
    "        self.sampler = iter(dgl.contrib.sampling.NeighborSampler(\n",
    "            g,                     # the graph\n",
    "            batch_size * 2,        # number of nodes to compute at a time, HACK 2\n",
    "            num_neighbors,         # number of neighbors for each node\n",
    "            num_layers,            # number of layers in GCN\n",
    "            seed_nodes=seed_nodes, # list of seed nodes, HACK 2\n",
    "            prefetch=True,         # whether to prefetch the NodeFlows\n",
    "            add_self_loop=True,    # whether to add a self-loop in the NodeFlows, HACK 1\n",
    "            shuffle=False,         # whether to shuffle the seed nodes.  Should be False here.\n",
    "            num_workers=4,\n",
    "        ))\n",
    "        self.i = 0\n",
    "        \n",
    "    def __iter__(self):\n",
    "        return self\n",
    "    \n",
    "    def __next__(self):\n",
    "        if self.i == len(self.src_batches):\n",
    "            raise StopIteration\n",
    "            \n",
    "        idx = self.i\n",
    "        self.i += 1\n",
    "        return (self.src_batches[idx], self.dst_batches[idx],\n",
    "                self.rating_batches[idx], next(self.sampler))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get the training set\n",
    "\n",
    "We use 80% of edges for training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "g_train = ml.g_train\n",
    "rating_train = g_train.edata['rating']\n",
    "src_train, dst_train = g_train.all_edges()\n",
    "print('#vertices:', g_train.number_of_nodes())\n",
    "print('#training edges:', g_train.number_of_edges())\n",
    "print('percentage:', g_train.number_of_edges() / g.number_of_edges() * 100, \"%\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get the testing edge set\n",
    "\n",
    "We use 20% of edges for evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "g = ml.g\n",
    "eid_test = g.filter_edges(lambda edges: edges.data['test']).astype('int64')\n",
    "src_test, dst_test = g.find_edges(eid_test)\n",
    "rating_test = g.edges[eid_test].data['rating']\n",
    "print('#testing edges:', len(eid_test))\n",
    "print('testing edges:', len(eid_test) / g.number_of_edges() * 100, \"%\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(g_train, src, dst, batch_size):\n",
    "    # Training\n",
    "    tot_loss = 0\n",
    "    num_batches = 0\n",
    "    for s, d, r, nodeflow in EdgeSampler(g_train, src, dst, rating_train):\n",
    "        with mx.autograd.record():\n",
    "            score = model.forward(nodeflow, s, d)\n",
    "            loss = ((score - r) ** 2).mean()\n",
    "            loss.backward()\n",
    "        trainer.step(1)\n",
    "        tot_loss += loss.asscalar()\n",
    "        num_batches += 1\n",
    "        \n",
    "    # Return the training loss\n",
    "    return tot_loss / num_batches"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The testing code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(g, batch_size):\n",
    "    # Validation & Test, we precompute GraphSage output for all nodes first.\n",
    "    sampler = dgl.contrib.sampling.NeighborSampler(\n",
    "        g,\n",
    "        batch_size,\n",
    "        num_neighbors,\n",
    "        num_layers,\n",
    "        seed_nodes=nd.arange(g.number_of_nodes()).astype('int64'),\n",
    "        prefetch=True,\n",
    "        add_self_loop=True,\n",
    "        shuffle=False,\n",
    "        num_workers=4\n",
    "    )\n",
    "\n",
    "    h = []\n",
    "    for nf in sampler:\n",
    "        h.append(model.encoder(nf))\n",
    "    h = nd.concat(*h, dim=0)\n",
    "\n",
    "    # Compute test RMSE\n",
    "    score = model.decoder(h[src_test], h[dst_test], src_test, dst_test)\n",
    "    test_rmse = nd.sqrt(((score - rating_test) ** 2).mean())\n",
    "    \n",
    "    return test_rmse.asscalar()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "model = GNNRecommender(GraphSageEncoder(100, 1, g_train),\n",
    "                       DotDecoder(g_train.number_of_nodes()))\n",
    "model.collect_params().initialize(ctx=mx.cpu())\n",
    "trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate': 0.001, 'wd': 1e-9})\n",
    "\n",
    "for epoch in range(50):\n",
    "    loss = train(g_train, src_train, dst_train, batch_size)\n",
    "    test_rmse = test(g, batch_size)\n",
    "    print('Training loss:', loss, 'Test RMSE:', test_rmse)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
